{
 "cells": [
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# Image classification\n",
    "*Thomas Dooms*"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "In the previous chapter, we discussed weight-based interpretability vs activation-based interpretability. We also explored bilinear layers and interaction matrices, a central object toward interpreting them. This chapter will cover small image classifiers trained on MNIST. Our interpretable, bilinear architecture will give us the tools necessary to understand what these models are learning. Previously, this was famously difficult; the Anthropic team once made the following statement [[ref](https://transformer-circuits.pub/2024/jan-update#mnist-sparse)].\n",
    "\n",
    "> Our failure to understand MNIST models might be seen as an example of a more general trend: mechanistic interpretability has had little success in reverse engineering MLP layers in small but real neural networks. ... perhaps small models are just pathological in some way and don't have \"crisp abstractions\"!\n",
    "\n",
    "We will show that small MNIST models actually learn very sensible structures once we understand how to look for them. This chapter is structured in a trial-and-error fashion, reflecting to some degree our process in interpreting these models. That said, let's dive right in."
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Setting up\n",
    "\n",
    "The model we are using is a slight extension of our previous architecture. We now have an embedding matrix, followed by the same MLP as before. Little has changed beyond the model's dimensions and the fact that we are training on MNIST."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "%load_ext autoreload\n",
    "%autoreload 2\n",
    "\n",
    "import torch\n",
    "import plotly.express as px\n",
    "\n",
    "from image import Model, MNIST\n",
    "from einops import einsum\n",
    "from kornia.augmentation import RandomGaussianNoise\n",
    "from image import plot_eigenspectrum, plot_explanation\n",
    "\n",
    "# Change this to \"cpu\" if you don't have a GPU\n",
    "device = \"cuda:0\"\n",
    "\n",
    "# Instantiate the model with the default configuration\n",
    "model = Model.from_config(epochs=20).to(device)\n",
    "\n",
    "# Load the MNIST dataset (directly on to a device for efficiency)\n",
    "train, test = MNIST(train=True, device=device), MNIST(train=False, device=device)\n",
    "metrics = model.fit(train, test)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Attempt 1: naive approach\n",
    "Let's simply re-use the code from the previous chapter, with a slight modification to account for the embedding matrix."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# .w_l and .w_r would return the left matrices for each layer. We only have one layer, so we index at 0.\n",
    "l, r = model.w_l[0], model.w_r[0]\n",
    "\n",
    "# We compute the same as previously, now taking into account the additional embedding matrix\n",
    "b = einsum(model.w_u, l, r, model.w_e, model.w_e, \"cls out, out hid1, out hid2, hid1 in1, hid2 in2 -> cls in1 in2\")\n",
    "b = 0.5 * (b + b.mT)\n",
    "\n",
    "px.imshow(b[0].cpu(), color_continuous_midpoint=0, color_continuous_scale=\"RdBu\")"
   ]
  },
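  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Written out, the einsum above computes one interaction matrix per class. As a sketch in notation we introduce here for illustration, let $U$ stand for `model.w_u`, $L$ and $R$ for `l` and `r`, and $E$ for `model.w_e`, with index names following the einsum string:\n",
    "\n",
    "$$B_{c,i,j} = \\sum_{o,h,h'} U_{c,o} \\, L_{o,h} \\, R_{o,h'} \\, E_{h,i} \\, E_{h',j}$$\n",
    "\n",
    "The symmetrization step then replaces each $B_c$ by $\\tfrac{1}{2}(B_c + B_c^\\top)$, which leaves the quadratic form $\\textbf{x}^\\top B_c \\textbf{x}$ unchanged for every input $\\textbf{x}$."
   ]
  },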
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "We get ... something. It's not random, but it's rather hard to determine what is happening. The reason is simple: viewing images as vectors loses spatial information. This is aggravated by considering the interactions between these vectors, resulting in doubly distorted information. Ideally, we'd somehow turn this interaction matrix into a sum of vectors, which all capture meaningful information about the interactions. That way, we can actually visualize them."
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Attempt 2: decomposition into components\n",
    "There are many matrix decompositions to achieve this. Due to the symmetric nature of our interaction matrices, we can use an especially nice decomposition: the eigendecomposition. Informally, the eigendecomposition converts a matrix into a collection of eigenvectors and eigenvalues. The eigenvector represents some matrix component, and the eigenvalue describes its importance. This is especially nice since we can order eigenvectors by relevance and analyze however many we wish, knowing we are doing so optimally.\n",
    "\n",
    "So, let's implement this approach. First, we convert the MLP part of the model into its interaction matrices. Then, we perform the eigendecomposition on this. There are a few noteworthy details to this.\n",
    "- This function only considers the last two dimensions for the decomposition; it views all others as a batch. Hence, it computes the decomposition for each output in our case, which is what we want.\n",
    "- There are some variants of the eigendecomposition in PyTorch; since we know our matrix is symmetric, this one is best.\n",
    "\n",
    "Lastly, we project the eigenvectors to the 'input space' and visualize them. For now, we simply visualize the most important (positive) eigenvector, which is stored at the last index."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Split the bilinear layer into the left and right components\n",
    "l, r = model.w_lr[0].unbind()\n",
    "\n",
    "# Compute the third-order (bilinear) tensor\n",
    "b = einsum(model.w_u, l, r, \"cls out, out in1, out in2 -> cls in1 in2\")\n",
    "\n",
    "# Symmetrize the tensor\n",
    "b = 0.5 * (b + b.mT)\n",
    "\n",
    "# Perform the eigendecomposition\n",
    "vals, vecs = torch.linalg.eigh(b)\n",
    "\n",
    "# Project the eigenvectors back to the input space\n",
    "vecs = einsum(vecs, model.w_e, \"cls emb comp, emb inp -> cls comp inp\")\n",
    "\n",
    "# Take the class (cls) for digit 0 and the last component (comp), which indicates the most positive eigenvalue\n",
    "px.imshow(vecs[0, -1].view(28, 28).cpu(), color_continuous_midpoint=0, color_continuous_scale=\"RdBu\")"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "While we see some spatial structure, this is still not what we expect (or what the paper shows). As we will see, this is actually a feature, not a bug! Essentially, since we did not regularize our model during training, it overfits. Specifically, it's picking up on some outlying pixels that probably don't occur often and uses those to memorize the class, hence their importance. This is not really what we want; we want to understand what a 'generalizing' model is doing."
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "=== Attempt 3: adding training regularization\n",
    "Our ``.fit()`` function accepts augmentations (a small function that is applied to all inputs). Common augmentation (or regularization) includes rotation and translation. Experimentally, we found that these don't always work for MNIST since the model can still memorize certain outlying patterns. We did find input noise (adding some noise to each pixel) to work quite well, so let's now retrain a model using that.\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "model = Model.from_config(epochs=20).to(device)\n",
    "metrics = model.fit(train, test, RandomGaussianNoise(std=0.4))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "vals, vecs = model.decompose()\n",
    "px.imshow(vecs[0, -1].view(28, 28).cpu(), color_continuous_midpoint=0, color_continuous_scale=\"RdBu\")"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Finally, we get discernable digits. With this setup, we can recreate the plots from the paper (after removing some visual bloat). We can get slightly cleaner features by training longer or adding slightly more noise. \n",
    "\n",
    "In Appendix B of the paper, we perform a broader study of how different types of regularization and augmentation impact learned features. As stated above, many common augmentations don't remove overfitting. Instead of single pixels, it overfits based on small regions (according to the augmentation)."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# The main visualization code\n",
    "fig = px.imshow(vecs[1:6, -1].view(-1, 28, 28).cpu(), color_continuous_midpoint=0, color_continuous_scale=\"RdBu\", facet_col=0)\n",
    "\n",
    "# Hide all the abundant information to get a clean plot\n",
    "fig.update_xaxes(showticklabels=False).update_yaxes(showticklabels=False)\n",
    "fig.update_layout(showlegend=False).update_coloraxes(showscale=False)\n",
    "fig.for_each_annotation(lambda a: a.update(text=\"\"))"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Multiple eigenvectors per digit.\n",
    "\n",
    "We've now recreated the result from the paper, but what does it actually mean? Let's write down the actual decomposition: $\\sum_i \\lambda_i (\\textbf{v}_i \\textbf{x}_i)^2$ where $\\lambda_i$ are the eigenvalues and $\\textbf{v}_i$ are the eigenvectors. We call $(\\textbf{v}_i \\textbf{x})^2$ the activation (as this is analogous to a normal activation, just in a different 'feature' space).\n",
    "\n",
    "Interpreting eigenvectors is not like interpreting a heatmap where blue is a positive contribution and red is a negative contribution. Rather, the colors represent a form of constructive and destructive interference:\n",
    "- If the input matches *none* of the colors, the activation will be *low*.\n",
    "- If the input matches *either* color, the activation will be *high*.\n",
    "- If the input matches *both* colors, the activation will be *low*.\n",
    "\n",
    "This is somewhat analogous to an XOR-gate. Also, notice that the activation is either low or high, never negative (due to the square). The sign of the eigenvalue fully determines whether a feature will contribute positively or negatively. Eigenvectors with positive eigenvalues generally represent specific parts of digits that we expect (and have seen until now). Negative eigenvectors are somewhat less intuitive; they generally represent specific parts that are most destructive towards classifying a given digit, which is often hard to understand. Consider the digit 4; if one were to add a vertical line between the top parts, it would most likely be a 9. Hence, it makes sense for the model to strongly negatively match this region.\n",
    "\n",
    "Let's visualize what these eigenvalues look like. "
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Visualize the eigenvalues for all digits\n",
    "px.line(vals.cpu().T, template=\"plotly_white\")"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "We see that in all cases, there are a few important eigenvalues (about 10 on each side).\n",
    "Plotting these important vectors isn't difficult, but let's use a fancy function that puts everything nicely side-by-side."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "plot_eigenspectrum(model, digit=5)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Using eigenvectors as explanations.\n",
    "The features we get are nice, but are they actually good at explaining why a model makes a certain prediction? We can check this by measuring the activations for a given input and then plotting the features that activated most strongly. We plot both the top and bottom activating eigenvector for the three digits that had the highest logits. We see that, generally, only very few eigenvectors contribute significantly."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "plot_explanation(model, train.x[0])"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Future directions.\n",
    "\n",
    "The decomposition above is only one way to analyze the interaction matrices. A good decomposition satisfies two criteria: short and interpretable. We satisfy the second criterion by considering features per digit, which are generally easier to understand. However, it may miss shared features, specific strokes that occur in multiple digits, which the model leverages. In that sense, our decomposition could be shorter. This becomes especially important when scaling up. It's currently somewhat unclear what the best way is to achieve both criteria. One possibility is dictionary learning, not SAEs, but rather 'Ye Olde' Lasso Regression.\n",
    "\n",
    "Properly scaling this method remains an open question. Our proposed method decomposes the model top-down (output to input). At each step, we compute eigenvectors and use those as output directions for the next layers. One could perform a sort of beam search to extract the most important global directions. However, this leverages no shared features in the model, which is suboptimal. Clustering could possibly help in this."
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.12.6"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
}
